{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z9ln8ylCKHsx"
   },
   "source": [
    "# L3-A - Linear Quantization II: Symmetric vs. Asymmetric Mode\n",
    "\n",
    "In this lesson, you will learn a second way of performing linear quantization: Symmetric Mode, in which the zero point is fixed at 0."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The libraries are already installed in the classroom. If you're running this notebook on your own machine, install the following:\n",
    "\n",
    "```Python\n",
    "!pip install torch==2.1.1\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear Quantization: Symmetric Mode"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
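    "- Implement a function that returns the `scale` for Linear Quantization in Symmetric Mode: `scale = r_max / q_max`, where `r_max` is the largest absolute value in the tensor and `q_max` is the largest value of the target `dtype`."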
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
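    "def get_q_scale_symmetric(tensor, dtype=torch.int8):\n",
    "    # largest absolute value in the tensor\n",
    "    r_max = tensor.abs().max().item()\n",
    "    # largest value representable by the target dtype (127 for int8)\n",
    "    q_max = torch.iinfo(dtype).max\n",
    "\n",
    "    # the scale maps [-r_max, r_max] onto [-q_max, q_max]\n",
    "    return r_max / q_max"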
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "### test the implementation on a 4x4 matrix\n",
    "test_tensor = torch.randn((4, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** Since the values are random, your results may differ from those shown in the video."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "test_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "get_q_scale_symmetric(test_tensor)"
   ]
  },
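  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `test_tensor` is random, the scale above changes on every run. As a minimal sketch, the same arithmetic can be checked by hand on fixed values; the tensor `example` below is an illustrative choice, not part of the lesson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# hand-picked values (an assumption for illustration, not from the lesson)\n",
    "example = torch.tensor([[-254.0, 100.0], [50.8, 0.0]])\n",
    "\n",
    "r_max = example.abs().max().item()   # 254.0, taken from the negative entry\n",
    "q_max = torch.iinfo(torch.int8).max  # 127\n",
    "example_scale = r_max / q_max\n",
    "\n",
    "print(example_scale)  # 2.0"
   ]
  },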
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Perform Linear Quantization in Symmetric Mode.\n",
    "- `linear_q_with_scale_and_zero_point` is the same function you implemented in the previous lesson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from helper import linear_q_with_scale_and_zero_point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "def linear_q_symmetric(tensor, dtype=torch.int8):\n",
    "    # pass dtype through so the scale matches the target integer range\n",
    "    scale = get_q_scale_symmetric(tensor, dtype=dtype)\n",
    "\n",
    "    # in symmetric quantization the zero point is always 0\n",
    "    quantized_tensor = linear_q_with_scale_and_zero_point(\n",
    "        tensor, scale=scale, zero_point=0, dtype=dtype)\n",
    "\n",
    "    return quantized_tensor, scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "quantized_tensor, scale = linear_q_symmetric(test_tensor)"
   ]
  },
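  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`linear_q_with_scale_and_zero_point` comes from `helper` and its body is not shown here. The sketch below assumes it divides by the scale, rounds, adds the zero point, and clamps to the range of `dtype`; with a zero point of 0 that reduces to a few lines. `symmetric_quantize_sketch` is a hypothetical name used only for this illustration, and the input values are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def symmetric_quantize_sketch(tensor, dtype=torch.int8):\n",
    "    # hypothetical stand-in for the helper, assuming round-then-clamp\n",
    "    q_min = torch.iinfo(dtype).min\n",
    "    q_max = torch.iinfo(dtype).max\n",
    "    scale = tensor.abs().max().item() / q_max\n",
    "    # the zero point is 0, so nothing is added after rounding\n",
    "    q = torch.round(tensor / scale).clamp(q_min, q_max).to(dtype)\n",
    "    return q, scale\n",
    "\n",
    "q, s = symmetric_quantize_sketch(torch.tensor([-254.0, 126.0, 0.0, 4.0]))\n",
    "# the largest magnitude lands on -127; -128 is never produced in this mode\n",
    "print(q, s)  # tensor([-127,   63,    0,    2], dtype=torch.int8) 2.0"
   ]
  },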
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dequantization\n",
    "\n",
    "- Perform Dequantization, using the same `scale` and a zero point of 0.\n",
    "- Plot the Quantization error.\n",
    "- `linear_dequantization` is the same function you implemented in the previous lesson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 63,
    "id": "FkGcWpzn4_xr"
   },
   "outputs": [],
   "source": [
    "from helper import linear_dequantization, plot_quantization_errors\n",
    "from helper import quantization_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)"
   ]
  },
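  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With a zero point of 0, dequantization is assumed here to reduce to multiplying the integers by `scale` (the `helper` implementation is not shown). The sketch below runs the round trip on made-up values and checks that rounding moves no element by more than half a scale step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# made-up values for illustration (not from the lesson)\n",
    "original = torch.tensor([-254.0, 125.5, 0.0, 4.6])\n",
    "scale = original.abs().max().item() / torch.iinfo(torch.int8).max  # 2.0\n",
    "\n",
    "quantized = torch.round(original / scale).clamp(-128, 127).to(torch.int8)\n",
    "# assumed dequantization rule: scale * (q - zero_point), with zero_point = 0\n",
    "dequantized = scale * quantized.float()\n",
    "\n",
    "max_abs_error = (dequantized - original).abs().max().item()\n",
    "print(dequantized)                 # tensor([-254.,  126.,    0.,    4.])\n",
    "print(max_abs_error <= scale / 2)  # True"
   ]
  },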
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "plot_quantization_errors(\n",
    "    test_tensor, quantized_tensor, dequantized_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "print(f\"\"\"Quantization Error : \\\n",
    "{quantization_error(test_tensor, dequantized_tensor)}\"\"\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "NgmjMISuIyzF",
    "3dMVgqcNJCqE",
    "MFx2m7RmzRd5",
    "l6XbkmOzYMrC",
    "8NS1TnQt6E6v"
   ],
   "provenance": [
    {
     "file_id": "12_pQW6LB80u98m72_YKwM4ph7eb5Uf3g",
     "timestamp": 1705360205453
    },
    {
     "file_id": "1U9pm4j_uAD8EO7OrEPvdpwFjQKO3suuH",
     "timestamp": 1700237940920
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
